#!/usr/bin/env python3
"""
Phase 4: Threshold & Spline Tests
Tables 10-13: Governance knots, KAOPEN knots, time-since-independence, rolling windows.
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Directory layout for this analysis, rooted under the shared project directory.
CCA_DIR = PROJECT_DIR / "cca_tipping"
PROCESSED_DIR = CCA_DIR.joinpath("data", "processed")
TABLE_DIR, FIG_DIR = (CCA_DIR / "output" / _sub for _sub in ("tables", "figures"))
# Create the output folders up front so the table/figure writes below have a destination.
for _out_dir in (TABLE_DIR, FIG_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# ── Helpers ────────────────────────────────────────────────────────────────
def star(p):
    """Return the conventional significance marker for p-value ``p``.

    ``"***"`` for p < 0.01, ``"**"`` for p < 0.05, ``"*"`` for p < 0.10 and
    an empty string otherwise (including NaN, which fails every comparison).
    """
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""

def run_gls(df, dep_var, indep_vars):
    """Fit ``PanelGLS`` of ``dep_var`` on ``indep_vars`` using complete cases of ``df``.

    Rows missing the dependent variable or any regressor are dropped first.
    ``df`` must also carry ``iso3`` and ``year`` columns. These are passed to
    ``PanelGLS.fit`` after y and X, presumably as the panel's country and time
    identifiers (confirm against ``model.PanelGLS``).

    Returns
    -------
    dict or None
        Keys ``r_squared``, ``n_obs``, ``n_countries``, plus ``coefficients`` /
        ``std_errors`` / ``p_values`` dicts keyed by the names in
        ``indep_vars``. Returns ``None`` when fewer than 50 complete rows
        remain, or when fitting raises; in that case the error is printed,
        not re-raised.
    """
    complete = df.dropna(subset=[dep_var] + indep_vars).copy()
    if complete.shape[0] < 50:
        return None  # too few complete observations to estimate

    try:
        response = complete[dep_var].values
        design = complete[indep_vars].values
        model = PanelGLS()
        model.fit(response, design, complete['iso3'].values, complete['year'].values)

        def by_regressor(values):
            # Label an estimate vector with the regressor names, in column order.
            return dict(zip(indep_vars, values))

        return {
            'r_squared': model.r_squared,
            'n_obs': model.n_obs,
            'n_countries': model.n_countries,
            'coefficients': by_regressor(model.beta),
            'std_errors': by_regressor(model.se),
            'p_values': by_regressor(model.pvalues),
        }
    except Exception as err:
        print(f"    GLS error: {err}")
        return None

# ── Load Data ──────────────────────────────────────────────────────────────
print("\n".join(["=" * 70, "PHASE 4: THRESHOLD & SPLINE TESTS", "=" * 70]))

# Estimation panel: rows with observed CA/GDP and year within 1986–2024 inclusive.
panel = pd.read_csv(PROCESSED_DIR / "cca_panel.csv")
est = panel.loc[panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)].copy()

# Demographic factors of interest plus the control set shared by the
# specifications below.
demo_vars = [f'Z_{i}' for i in (1, 2, 3)]
baseline_controls = [
    'fiscal_bal_gdp',
    'kaopen',
    'expected_growth',
    'nfa_gdp_lag',
    'log_rel_opw',
    'health_exp_gdp',
]
base_vars = [*demo_vars, *baseline_controls]

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 10: GOVERNANCE SPLINE — Z₁ BELOW vs ABOVE KNOT
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 10: GOVERNANCE SPLINE (Z₁ × governance)")
print("=" * 70)

# Estimation sample: outcome, every baseline regressor and the governance
# moderator all observed.
gov_sample = est.dropna(subset=base_vars + ['ca_gdp', 'governance_composite']).copy()

# Candidate knots sit at these percentiles of the governance distribution.
knot_percentiles = [10, 20, 25, 30, 40, 50, 60, 75]
gov_values = gov_sample['governance_composite'].dropna()
gov_knots = dict(zip(knot_percentiles,
                     (np.percentile(gov_values, q) for q in knot_percentiles)))

# Z₁ enters only through a piecewise-linear interaction with governance:
#   Z₁·min(gov, knot)  and  Z₁·max(0, gov − knot).
# NOTE(review): no stand-alone Z₁ term and no governance level term are
# included in these specifications — confirm that is intended.
spline_vars = ['Z1_gov_below', 'Z1_gov_above', 'Z_2', 'Z_3'] + baseline_controls

spline_rows = []
for pct, cut in gov_knots.items():
    governance = gov_sample['governance_composite']
    gov_sample['Z1_gov_below'] = gov_sample['Z_1'] * governance.clip(upper=cut)
    gov_sample['Z1_gov_above'] = gov_sample['Z_1'] * (governance - cut).clip(lower=0)

    fit = run_gls(gov_sample, 'ca_gdp', spline_vars)
    if fit is None:
        continue

    coefs, ses, pvals = fit['coefficients'], fit['std_errors'], fit['p_values']
    entry = {'knot_percentile': pct, 'knot_value': cut}
    for side in ('below', 'above'):
        term = f'Z1_gov_{side}'
        entry[f'Z1_{side}_coef'] = coefs[term]
        entry[f'Z1_{side}_se'] = ses[term]
        entry[f'Z1_{side}_pval'] = pvals[term]
    entry['r_squared'] = fit['r_squared']
    entry['n_obs'] = fit['n_obs']
    entry['n_countries'] = fit['n_countries']
    # Change in the Z₁ interaction slope across the knot.
    entry['diff'] = coefs['Z1_gov_above'] - coefs['Z1_gov_below']
    spline_rows.append(entry)

    print(f"  Knot at p{pct} (gov={cut:.2f}): "
          f"below={entry['Z1_below_coef']:.2f}{star(entry['Z1_below_pval'])}, "
          f"above={entry['Z1_above_coef']:.2f}{star(entry['Z1_above_pval'])}, "
          f"R²={entry['r_squared']:.4f}")

spline_df = pd.DataFrame(spline_rows)
spline_df.to_csv(TABLE_DIR / "table10_governance_spline.csv", index=False)

# Report the knot whose specification has the highest R².
if spline_rows:
    best = spline_df.loc[spline_df['r_squared'].idxmax()]
    print(f"\n  Best knot: p{best['knot_percentile']:.0f} (gov={best['knot_value']:.2f}), "
          f"R²={best['r_squared']:.4f}")

md = [
    "# Table 10: Governance Spline — Z₁ Effect Below vs Above Knot\n",
    "| Knot (pctl) | Gov value | Z₁ below | SE | p-val | Z₁ above | SE | p-val | R² |",
    "|---|---|---|---|---|---|---|---|---|",
]
for _, rec in spline_df.iterrows():
    md.append(
        "| p{:.0f} | {:.2f} | {:.2f}{} | {:.2f} | {:.4f} | {:.2f}{} | {:.2f} | {:.4f} | {:.4f} |".format(
            rec['knot_percentile'], rec['knot_value'],
            rec['Z1_below_coef'], star(rec['Z1_below_pval']), rec['Z1_below_se'], rec['Z1_below_pval'],
            rec['Z1_above_coef'], star(rec['Z1_above_pval']), rec['Z1_above_se'], rec['Z1_above_pval'],
            rec['r_squared'],
        )
    )
(TABLE_DIR / "table10_governance_spline.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 11: KAOPEN SPLINE
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 11: KAOPEN SPLINE")
print("=" * 70)

# Same piecewise-linear Z₁ × moderator design as Table 10, with kaopen as the
# moderator: Z₁·min(kaopen, knot) and Z₁·max(0, kaopen − knot).
kaopen_sample = est.dropna(subset=base_vars + ['ca_gdp']).copy()
kaopen_pctls = [25, 50, 75]
# kaopen is one of base_vars, so kaopen_sample has no missing kaopen values.
kaopen_knots = {p: np.percentile(kaopen_sample['kaopen'], p) for p in kaopen_pctls}

# Regressor list is the same for every knot, so build it once.
# The kaopen level term is removed from the controls when its Z₁ spline is added.
# NOTE(review): this leaves the direct effect of kaopen on CA/GDP out of the
# model, so the spline slopes may partly absorb it — confirm this is the
# intended specification.
controls_no_kaopen = [v for v in baseline_controls if v != 'kaopen']
spline_vars = ['Z1_kaopen_below', 'Z1_kaopen_above', 'Z_2', 'Z_3'] + controls_no_kaopen

kaopen_rows = []
for pctl, knot in kaopen_knots.items():
    kaopen_sample['Z1_kaopen_below'] = kaopen_sample['Z_1'] * kaopen_sample['kaopen'].clip(upper=knot)
    kaopen_sample['Z1_kaopen_above'] = kaopen_sample['Z_1'] * (kaopen_sample['kaopen'] - knot).clip(lower=0)

    r = run_gls(kaopen_sample, 'ca_gdp', spline_vars)
    if r is None:
        continue

    coefs, ses, pvals = r['coefficients'], r['std_errors'], r['p_values']
    row = {
        'knot_percentile': pctl,
        'knot_value': knot,
        'Z1_below_coef': coefs['Z1_kaopen_below'],
        'Z1_below_pval': pvals['Z1_kaopen_below'],
        'Z1_above_coef': coefs['Z1_kaopen_above'],
        'Z1_above_pval': pvals['Z1_kaopen_above'],
        'r_squared': r['r_squared'],
        'n_obs': r['n_obs'],
        # Trailing columns: the extra statistics Table 10 records for the same
        # spline design. Kept at the end so the leading columns keep their positions.
        'Z1_below_se': ses['Z1_kaopen_below'],
        'Z1_above_se': ses['Z1_kaopen_above'],
        'n_countries': r['n_countries'],
        'diff': coefs['Z1_kaopen_above'] - coefs['Z1_kaopen_below'],
    }
    kaopen_rows.append(row)
    print(f"  Knot at p{pctl} (kaopen={knot:.2f}): below={row['Z1_below_coef']:.2f}{star(row['Z1_below_pval'])}, "
          f"above={row['Z1_above_coef']:.2f}{star(row['Z1_above_pval'])}")

kaopen_df = pd.DataFrame(kaopen_rows)
kaopen_df.to_csv(TABLE_DIR / "table11_kaopen_spline.csv", index=False)

md = ["# Table 11: KAOPEN Spline — Z₁ Effect Below vs Above Knot\n"]
md.append("| Knot (pctl) | KAOPEN value | Z₁ below | p-val | Z₁ above | p-val | R² |")
md.append("|---|---|---|---|---|---|---|")
for _, row in kaopen_df.iterrows():
    md.append(f"| p{row['knot_percentile']:.0f} | {row['knot_value']:.2f} | "
              f"{row['Z1_below_coef']:.2f}{star(row['Z1_below_pval'])} | {row['Z1_below_pval']:.4f} | "
              f"{row['Z1_above_coef']:.2f}{star(row['Z1_above_pval'])} | {row['Z1_above_pval']:.4f} | "
              f"{row['r_squared']:.4f} |")
(TABLE_DIR / "table11_kaopen_spline.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 12: TIME-SINCE-INDEPENDENCE SPLINE
# ═══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TABLE 12: TIME-SINCE-INDEPENDENCE")
print("=" * 70)

# Only transition countries have years_since_indep
trans_sample = est[est['is_transition'] == 1].dropna(subset=base_vars + ['ca_gdp', 'years_since_indep']).copy()
print(f"  Transition sample: {trans_sample['iso3'].nunique()} countries, {len(trans_sample):,} obs")

# Baseline on transition countries only
r_trans_base = run_gls(trans_sample, 'ca_gdp', base_vars)
if r_trans_base:
    print(f"  Transition baseline: Z₁={r_trans_base['coefficients']['Z_1']:.2f}"
          f"{star(r_trans_base['p_values']['Z_1'])}")

# Z₁ × years_since_indep interaction. YSI is demeaned within the transition
# sample, so the Z₁ coefficient is its slope at the sample-mean YSI.
trans_sample['ysi_dm'] = trans_sample['years_since_indep'] - trans_sample['years_since_indep'].mean()
trans_sample['Z_1_x_ysi'] = trans_sample['Z_1'] * trans_sample['ysi_dm']

ysi_vars = base_vars + ['ysi_dm', 'Z_1_x_ysi']
r_ysi = run_gls(trans_sample, 'ca_gdp', ysi_vars)
if r_ysi:
    print(f"  + Z₁×years_since_indep: Z₁={r_ysi['coefficients']['Z_1']:.2f}{star(r_ysi['p_values']['Z_1'])}, "
          f"Z₁×ysi={r_ysi['coefficients']['Z_1_x_ysi']:.3f}{star(r_ysi['p_values']['Z_1_x_ysi'])}")

# Piecewise-linear Z₁ × YSI splines with knots at 10, 15 and 20 years since
# independence. Estimates are collected and saved to their own CSV.
# NOTE(review): these specs include neither a stand-alone Z₁ term nor a YSI
# level term — confirm that is intended.
ysi_spline_rows = []
for knot_yr in [10, 15, 20]:
    below_col = f'Z1_ysi_below_{knot_yr}'
    above_col = f'Z1_ysi_above_{knot_yr}'
    trans_sample[below_col] = trans_sample['Z_1'] * trans_sample['years_since_indep'].clip(upper=knot_yr)
    trans_sample[above_col] = trans_sample['Z_1'] * (trans_sample['years_since_indep'] - knot_yr).clip(lower=0)

    ysi_spline_vars = [below_col, above_col, 'Z_2', 'Z_3'] + baseline_controls
    r_sp = run_gls(trans_sample, 'ca_gdp', ysi_spline_vars)
    if r_sp is None:
        continue

    print(f"  Knot at {knot_yr} years: below={r_sp['coefficients'][below_col]:.3f}"
          f"{star(r_sp['p_values'][below_col])}, "
          f"above={r_sp['coefficients'][above_col]:.3f}"
          f"{star(r_sp['p_values'][above_col])}")
    ysi_spline_rows.append({
        'knot_years': knot_yr,
        'Z1_below_coef': r_sp['coefficients'][below_col],
        'Z1_below_se': r_sp['std_errors'][below_col],
        'Z1_below_pval': r_sp['p_values'][below_col],
        'Z1_above_coef': r_sp['coefficients'][above_col],
        'Z1_above_se': r_sp['std_errors'][above_col],
        'Z1_above_pval': r_sp['p_values'][above_col],
        'r_squared': r_sp['r_squared'],
        'n_obs': r_sp['n_obs'],
        'n_countries': r_sp['n_countries'],
    })

pd.DataFrame(ysi_spline_rows).to_csv(TABLE_DIR / "table12_ysi_spline.csv", index=False)

# Save
ysi_results = []
if r_trans_base:
    ysi_results.append({'spec': 'Transition baseline', 'Z_1_coef': r_trans_base['coefficients']['Z_1'],
                        'Z_1_pval': r_trans_base['p_values']['Z_1'], 'r_squared': r_trans_base['r_squared'],
                        'n_obs': r_trans_base['n_obs']})
if r_ysi:
    ysi_results.append({'spec': '+ Z₁×years_since_indep', 'Z_1_coef': r_ysi['coefficients']['Z_1'],
                        'Z_1_pval': r_ysi['p_values']['Z_1'], 'r_squared': r_ysi['r_squared'],
                        'n_obs': r_ysi['n_obs'],
                        'Z_1_x_ysi': r_ysi['coefficients']['Z_1_x_ysi'],
                        'Z_1_x_ysi_pval': r_ysi['p_values']['Z_1_x_ysi']})

pd.DataFrame(ysi_results).to_csv(TABLE_DIR / "table12_time_since_independence.csv", index=False)

md = ["# Table 12: Time-Since-Independence (Transition Countries Only)\n"]
md.append("| Specification | Z₁ | p-val | Z₁×YSI | p-val | R² | N |")
md.append("|---|---|---|---|---|---|---|")
for res in ysi_results:
    # The baseline spec has no interaction term; show a dash in those cells.
    ysi_c = f"{res['Z_1_x_ysi']:.3f}" if 'Z_1_x_ysi' in res else "—"
    ysi_p = f"{res['Z_1_x_ysi_pval']:.4f}" if 'Z_1_x_ysi_pval' in res else "—"
    md.append(f"| {res['spec']} | {res['Z_1_coef']:.2f}{star(res['Z_1_pval'])} | {res['Z_1_pval']:.4f} | "
              f"{ysi_c} | {ysi_p} | {res['r_squared']:.4f} | {res['n_obs']} |")
(TABLE_DIR / "table12_time_since_independence.md").write_text("\n".join(md))

# ═══════════════════════════════════════════════════════════════════════════
# TABLE 13: ROLLING 15-YEAR WINDOWS
# ═══════════════════════════════════════════════════════════════════════════
window_size = 15
# Same year bounds as the `est` filter applied when loading the panel.
first_year, last_year = 1986, 2024

print("\n" + "=" * 70)
print(f"TABLE 13: ROLLING WINDOWS ({window_size}-year)")
print("=" * 70)

full_sample = est.dropna(subset=base_vars + ['ca_gdp']).copy()

window_rows = []
# Start years run so that the final window ends exactly at last_year.
for start_year in range(first_year, last_year - window_size + 2):
    end_year = start_year + window_size - 1
    win = full_sample[(full_sample['year'] >= start_year) & (full_sample['year'] <= end_year)]

    # Skip thin windows before attempting an estimate.
    if len(win) < 100:
        continue

    r = run_gls(win, 'ca_gdp', base_vars)
    if r is None:
        continue

    row = {
        'window_start': start_year,
        'window_end': end_year,
        'mid_year': start_year + window_size // 2,
        'n_countries': r['n_countries'],
        'n_obs': r['n_obs'],
        'r_squared': r['r_squared'],
    }
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_se'] = r['std_errors'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    window_rows.append(row)

window_df = pd.DataFrame(window_rows)
window_df.to_csv(TABLE_DIR / "table13_rolling_windows.csv", index=False)

# Print summary. iterrows() does not preserve dtypes: a row mixing int and
# float columns comes back as floats, so integer fields are formatted with :.0f.
for _, row in window_df.iterrows():
    sig = star(row['Z_1_pval'])
    print(f"  {row['window_start']:.0f}-{row['window_end']:.0f}: Z₁={row['Z_1_coef']:7.2f}{sig} "
          f"(p={row['Z_1_pval']:.4f}), N_c={row['n_countries']:.0f}")

# Rolling window figure: Z₁ point estimates with a ±1.96·SE band.
if len(window_df) > 2:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(window_df['mid_year'], window_df['Z_1_coef'], 'b-o', markersize=4, label='Z₁ coefficient')
    ax.fill_between(window_df['mid_year'],
                    window_df['Z_1_coef'] - 1.96 * window_df['Z_1_se'],
                    window_df['Z_1_coef'] + 1.96 * window_df['Z_1_se'],
                    alpha=0.2, color='b')
    ax.axhline(0, color='k', linestyle='--', linewidth=0.5)
    ax.set_xlabel("Window midpoint")
    ax.set_ylabel("Z₁ coefficient on CA/GDP")
    ax.set_title(f"Rolling {window_size}-Year Window: Z₁ Coefficient Evolution")
    ax.legend()
    fig.tight_layout()
    fig.savefig(FIG_DIR / "fig1_rolling_z1.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"\n  Saved figure: {FIG_DIR / 'fig1_rolling_z1.png'}")

# Closing summary: where the tables went, then a completion marker.
print(f"\nAll Phase 4 tables saved to {TABLE_DIR}", "Phase 4 complete.", sep="\n")
